{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# coding=utf-8\n",
    "from typing import Dict\n",
    "import time \n",
    "import pandas as pd \n",
    "\n",
    "import torch\n",
    "from datasets import Dataset, load_dataset\n",
    "from transformers import PreTrainedTokenizerFast, Seq2SeqTrainer, DataCollatorForSeq2Seq,Seq2SeqTrainingArguments\n",
    "from transformers.generation.configuration_utils import GenerationConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "# Make the project importable: take the current working directory, normalize\n",
    "# Windows backslashes to '/', drop the last two path components and put the\n",
    "# resulting ancestor directory on sys.path (presumably the repository root that\n",
    "# holds the `model`, `config` and `utils` modules imported below -- confirm).\n",
    "root = os.path.realpath('.').replace('\\\\','/').split('/')[0: -2]\n",
    "root = '/'.join(root)\n",
    "if root not in sys.path:\n",
    "     sys.path.append(root)\n",
    "\n",
    "from model.chat_model import TextToTextModel\n",
    "from config import SFTconfig, InferConfig, T5ModelConfig\n",
    "from utils.functions import get_T5_config\n",
    "\n",
    "# TensorFlow reads this variable to decide whether to enable its oneDNN custom ops;\n",
    "# '0' turns them off. NOTE(review): TensorFlow is not imported in the visible cells,\n",
    "# so this presumably targets a transitive import -- confirm it is still needed.\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(file: str, split: str, encode_fn: callable, encode_args: dict,  cache_dir: str='.cache') -> Dataset:\n",
    "    \"\"\"\n",
    "    Read a JSON dataset and tokenize the prompt and response of every record.\n",
    "\n",
    "    Args:\n",
    "        file: value passed to `load_dataset` as `data_files`.\n",
    "        split: split name passed to `load_dataset`.\n",
    "        encode_fn: callable used to tokenize text; its result must expose\n",
    "            `input_ids` and `attention_mask` attributes.\n",
    "        encode_args: keyword arguments forwarded to every `encode_fn` call.\n",
    "        cache_dir: cache directory passed to `load_dataset`.\n",
    "\n",
    "    Returns:\n",
    "        The dataset produced by `map`; the mapper contributes the fields\n",
    "        `input_ids`, `input_mask` and `labels` for each record.\n",
    "    \"\"\"\n",
    "    def tokenize_record(sample: dict) -> dict:\n",
    "        # The literal \"[EOS]\" suffix marks the end of the sentence; generation relies on it.\n",
    "        encoded_prompt = encode_fn(f\"{sample['prompt']}[EOS]\", **encode_args)\n",
    "        encoded_response = encode_fn(f\"{sample['response']}[EOS]\", **encode_args)\n",
    "\n",
    "        features = {}\n",
    "        features['input_ids'] = encoded_prompt.input_ids\n",
    "        features['input_mask'] = encoded_prompt.attention_mask\n",
    "        features['labels'] = encoded_response.input_ids\n",
    "        return features\n",
    "\n",
    "    raw_records = load_dataset('json', data_files=file,  split=split, cache_dir=cache_dir)\n",
    "    return raw_records.map(tokenize_record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sft_train(config: SFTconfig) -> Seq2SeqTrainer:\n",
    "    \"\"\"\n",
    "    Fine-tune the text-to-text model on the SFT data described by `config`.\n",
    "\n",
    "    The function loads the tokenizer and the starting weights, tokenizes the\n",
    "    training file, trains with `Seq2SeqTrainer`, saves the model to\n",
    "    `config.output_dir` and finally writes the trainer's log history to a\n",
    "    timestamped CSV under `{root}/logs` (`root` is the module-level path\n",
    "    computed in the setup cell).\n",
    "\n",
    "    Args:\n",
    "        config: SFT settings; the attributes read here are tokenizer_dir,\n",
    "            finetune_from_ckp_file, sft_train_file, output_dir, batch_size,\n",
    "            gradient_accumulation_steps, learning_rate, logging_steps,\n",
    "            num_train_epochs, save_steps, fp16, logging_first_step,\n",
    "            warmup_steps, seed and max_seq_len.\n",
    "\n",
    "    Returns:\n",
    "        The `Seq2SeqTrainer` after training, so callers can inspect its state.\n",
    "    \"\"\"\n",
    "    # Step 1: load the tokenizer\n",
    "    tokenizer = PreTrainedTokenizerFast.from_pretrained(config.tokenizer_dir)\n",
    "\n",
    "    # Step 2: load the pre-trained model\n",
    "    if os.path.isdir(config.finetune_from_ckp_file):\n",
    "        # a directory was given: load it with from_pretrained\n",
    "        model = TextToTextModel.from_pretrained(config.finetune_from_ckp_file)\n",
    "    else:\n",
    "        # a single checkpoint file was given: build the model, then load_state_dict\n",
    "        t5_config = get_T5_config(T5ModelConfig(), vocab_size=len(tokenizer), decoder_start_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)\n",
    "        model = TextToTextModel(t5_config)\n",
    "        # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "        model.load_state_dict(torch.load(config.finetune_from_ckp_file, map_location='cpu')) # set cpu for no exception\n",
    "\n",
    "    # Step 3: load the dataset\n",
    "    encode_args = {\n",
    "        'truncation': False,\n",
    "        'padding': 'max_length',\n",
    "    }\n",
    "\n",
    "    dataset = get_dataset(file=config.sft_train_file, encode_fn=tokenizer.encode_plus, encode_args=encode_args, split=\"train\")\n",
    "\n",
    "    # Step 4: define the generation config and the training arguments\n",
    "    # T5 is a sequence-to-sequence model, so Seq2SeqTrainingArguments, DataCollatorForSeq2Seq and Seq2SeqTrainer must be used;\n",
    "    # the SFT tooling on the Hugging Face site targets language models (LM).\n",
    "    generation_config = GenerationConfig()\n",
    "    generation_config.remove_invalid_values = True\n",
    "    generation_config.eos_token_id = tokenizer.eos_token_id\n",
    "    generation_config.pad_token_id = tokenizer.pad_token_id\n",
    "    generation_config.decoder_start_token_id = tokenizer.pad_token_id\n",
    "    generation_config.max_new_tokens = 320\n",
    "    generation_config.repetition_penalty = 1.5\n",
    "    generation_config.num_beams = 1         # greedy search\n",
    "    generation_config.do_sample = False     # greedy search\n",
    "\n",
    "    training_args = Seq2SeqTrainingArguments(\n",
    "        output_dir=config.output_dir,\n",
    "        per_device_train_batch_size=config.batch_size,\n",
    "        auto_find_batch_size=True,  # guard against OOM\n",
    "        gradient_accumulation_steps=config.gradient_accumulation_steps,\n",
    "        learning_rate=config.learning_rate,\n",
    "        logging_steps=config.logging_steps,\n",
    "        num_train_epochs=config.num_train_epochs,\n",
    "        optim=\"adafactor\",\n",
    "        report_to='tensorboard',\n",
    "        log_level='info',\n",
    "        save_steps=config.save_steps,\n",
    "        save_total_limit=3,\n",
    "        fp16=config.fp16,\n",
    "        logging_first_step=config.logging_first_step,\n",
    "        warmup_steps=config.warmup_steps,\n",
    "        seed=config.seed,\n",
    "        generation_config=generation_config,\n",
    "    )\n",
    "\n",
    "    # Step 5: init a collator\n",
    "    collator = DataCollatorForSeq2Seq(tokenizer, max_length=config.max_seq_len)\n",
    "\n",
    "    # Step 6: define the Trainer (the training set is also passed as the eval set)\n",
    "    trainer = Seq2SeqTrainer(\n",
    "        model=model,\n",
    "        args=training_args,\n",
    "        train_dataset=dataset,\n",
    "        eval_dataset=dataset,\n",
    "        tokenizer=tokenizer,\n",
    "        data_collator=collator,\n",
    "    )\n",
    "\n",
    "    # Step 7: train\n",
    "    trainer.train(\n",
    "        # resume_from_checkpoint=True\n",
    "    )\n",
    "\n",
    "    # Step 8: save the model first, so a failure while writing the log cannot lose the trained weights\n",
    "    trainer.save_model(config.output_dir)\n",
    "\n",
    "    # Step 9: write the loss log; create the log directory if it does not exist yet\n",
    "    log_dir = f\"{root}/logs\"\n",
    "    os.makedirs(log_dir, exist_ok=True)\n",
    "    loss_log = pd.DataFrame(trainer.state.log_history)\n",
    "    loss_log.to_csv(f\"{log_dir}/ie_task_finetune_log_{time.strftime('%Y%m%d-%H%M')}.csv\")\n",
    "\n",
    "    return trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from the project's default SFT settings and override them for this run.\n",
    "config = SFTconfig()\n",
    "# Fine-tune from whatever InferConfig().model_dir points at\n",
    "# (sft_train accepts either a model directory or a state-dict file).\n",
    "config.finetune_from_ckp_file = InferConfig().model_dir\n",
    "config.sft_train_file = './data/my_train.json'\n",
    "config.output_dir = './model_save/ie_task'\n",
    "config.max_seq_len = 512\n",
    "# 16 samples per device x 4 accumulation steps per optimizer update\n",
    "# (sft_train also enables auto_find_batch_size, which may lower the per-device size).\n",
    "config.batch_size = 16\n",
    "config.gradient_accumulation_steps = 4\n",
    "config.logging_steps = 20\n",
    "config.learning_rate = 5e-5\n",
    "config.num_train_epochs = 6\n",
    "config.save_steps = 3000\n",
    "config.warmup_steps = 1000\n",
    "# Show the effective configuration so it is recorded with the run.\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = sft_train(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation section. The execution count restarts at 1 here, i.e. this part was run\n",
    "# in a fresh kernel, which is why the sys.path setup from the training section is repeated.\n",
    "import sys, os\n",
    "root = os.path.realpath('.').replace('\\\\','/').split('/')[0: -2]\n",
    "root = '/'.join(root)\n",
    "if root not in sys.path:\n",
    "     sys.path.append(root)\n",
    "import ujson, torch\n",
    "from rich import progress\n",
    "\n",
    "from model.infer import ChatBot\n",
    "from config import InferConfig\n",
    "from utils.functions import f1_p_r_compute\n",
    "# Point inference at the directory the fine-tuned model was saved to\n",
    "# (the same location as config.output_dir in the training section).\n",
    "inf_conf = InferConfig()\n",
    "inf_conf.model_dir = './model_save/ie_task/'\n",
    "bot = ChatBot(infer_config=inf_conf)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(傅淑云,民族,汉族),(傅淑云,出生地,上海),(傅淑云,出生日期,1915年)]\n"
     ]
    }
   ],
   "source": [
    "ret = bot.chat('请抽取出给定句子中的所有三元组。给定句子：傅淑云，女，汉族，1915年出生，上海人')\n",
    "print(ret)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('傅淑云', '民族', '汉族'), ('傅淑云', '出生地', '上海'), ('傅淑云', '出生日期', '1915年')]\n"
     ]
    }
   ],
   "source": [
    "def text_to_spo_list(sentence: str) -> list:\n",
    "    '''\n",
    "    Parse text such as \"[(s,p,o),(s,p,o)]\" into a list of SPO triples.\n",
    "\n",
    "    Full-width commas and parentheses are first normalized to their ASCII\n",
    "    forms. The text is then scanned once from left to right: \"(\" opens a\n",
    "    group, \",\" closes a field and \")\" closes the group. A triple is emitted\n",
    "    only when \")\" is reached with exactly two fields collected and a non-empty\n",
    "    third field pending; any other group is dropped. Square brackets are ignored.\n",
    "\n",
    "    Args:\n",
    "        sentence: raw text to parse, e.g. the model output.\n",
    "\n",
    "    Returns:\n",
    "        A list of (subject, predicate, object) tuples of strings, in order of\n",
    "        appearance; an empty list when no well-formed triple is found.\n",
    "    '''\n",
    "    spo_list = []\n",
    "    sentence = sentence.replace('，',',').replace('（','(').replace('）', ')') # normalize punctuation\n",
    "\n",
    "    cur_txt, cur_spo, started = '',  [], False\n",
    "    for char in sentence:\n",
    "        if char not in '[](),':\n",
    "            cur_txt += char\n",
    "        elif char == '(':\n",
    "            # a new group starts: discard anything collected so far\n",
    "            started = True\n",
    "            cur_txt, cur_spo = '' , []\n",
    "        elif char == ',' and started and len(cur_txt) > 0 and len(cur_spo) < 3:\n",
    "            cur_spo.append(cur_txt)\n",
    "            cur_txt = ''\n",
    "        elif char == ')' and started and len(cur_txt) > 0 and len(cur_spo) == 2:\n",
    "            # third field complete: emit the triple and wait for the next '('\n",
    "            cur_spo.append(cur_txt)\n",
    "            spo_list.append(tuple(cur_spo))\n",
    "            cur_spo = []\n",
    "            cur_txt = ''\n",
    "            started = False\n",
    "    return spo_list\n",
    "print(text_to_spo_list(ret))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the evaluation records from JSON. The sample displayed in the next cell shows\n",
    "# that each record is a dict with a 'prompt' and a 'response' key.\n",
    "test_data = []\n",
    "with open('./data/test.json', 'r', encoding='utf-8') as f:\n",
    "    test_data = ujson.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'prompt': '请抽取出给定句子中的所有三元组。给定句子：查尔斯·阿兰基斯（charles aránguiz），1989年4月17日出生于智利圣地亚哥，智利职业足球运动员，司职中场，效力于德国足球甲级联赛勒沃库森足球俱乐部',\n",
       "  'response': '[(查尔斯·阿兰基斯,出生地,圣地亚哥),(查尔斯·阿兰基斯,出生日期,1989年4月17日)]'},\n",
       " {'prompt': '请抽取出给定句子中的所有三元组。给定句子：《离开》是由张宇谱曲，演唱',\n",
       "  'response': '[(离开,歌手,张宇),(离开,作曲,张宇)]'}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_data[0:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bca40f71fcc34dda95eb97a6f48fea0c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Run the model over the test set in batches of `batch_size` prompts, collecting the\n",
    "# reference triples and the predicted triples for every sample.\n",
    "# NOTE: `traget_spo_list` is a typo for \"target\"; later cells use the same name,\n",
    "# so it must be renamed everywhere at once.\n",
    "prompt_buffer, batch_size, n = [], 32, len(test_data)\n",
    "traget_spo_list, predict_spo_list = [], []\n",
    "for i, item in progress.track(enumerate(test_data), total=n):\n",
    "    prompt_buffer.append(item['prompt'])\n",
    "    traget_spo_list.append(\n",
    "        text_to_spo_list(item['response'])\n",
    "    )\n",
    "\n",
    "    # Flush when the buffer is full, or on the last sample so a partial batch is not lost.\n",
    "    if len(prompt_buffer) == batch_size or i == n - 1:\n",
    "        torch.cuda.empty_cache()\n",
    "        # assumes bot.chat returns one string per prompt, in input order -- confirm;\n",
    "        # the positional pairing with traget_spo_list depends on it\n",
    "        model_pred = bot.chat(prompt_buffer)\n",
    "        # `item` inside the comprehension is local to it and does not rebind the loop variable\n",
    "        model_pred = [text_to_spo_list(item) for item in model_pred]\n",
    "        predict_spo_list.extend(model_pred)\n",
    "        prompt_buffer = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[('查尔斯·阿兰基斯', '出生地', '圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')], [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]] \n",
      "\n",
      "\n",
      " [[('查尔斯·阿兰基斯', '国籍', '智利'), ('查尔斯·阿兰基斯', '出生地', '智利圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')], [('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]]\n"
     ]
    }
   ],
   "source": [
    "print(traget_spo_list[0:2], '\\n\\n\\n',predict_spo_list[0:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21636 21636\n"
     ]
    }
   ],
   "source": [
    "print(len(predict_spo_list), len(traget_spo_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1: 0.74, precision： 0.75, recall: 0.73\n"
     ]
    }
   ],
   "source": [
    "f1, p, r = f1_p_r_compute(predict_spo_list, traget_spo_list)\n",
    "print(f\"f1: {f1:.2f}, precision： {p:.2f}, recall: {r:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['你好，有什么我可以帮你的吗？',\n",
       " '[(江苏省赣榆海洋经济开发区,成立日期,2003年1月28日)]',\n",
       " '南方地区气候干燥，气候寒冷，冬季寒冷，夏季炎热，冬季寒冷的原因很多，可能是由于全球气候变暖导致的。\\n南方气候的变化可以引起天气的变化，例如气温下降、降雨增多、冷空气南下等。南方气候的变化可以促进气候的稳定，有利于经济发展和经济繁荣。\\n此外，南方地区的气候也可能受到自然灾害的影响，例如台风、台风、暴雨等，这些自然灾害会对南方气候产生影响。\\n总之，南方气候的变化是一个复杂的过程，需要综合考虑多方面因素，才能应对。']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Quick check of the model's conversational ability: one batch containing a greeting,\n",
    "# a triple-extraction prompt and an open-ended question.\n",
    "bot.chat(['你好', '请抽取出给定句子中的所有三元组。给定句子：江苏省赣榆海洋经济开发区位于赣榆区青口镇临海而建，2003年1月28日，经江苏省人民政府《关于同意设立赣榆海洋经济开发区的批复》（苏政复〔2003〕14号）文件批准为全省首家省级海洋经济开发区，','如何看待最近南方天气突然变冷？'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
